{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {},
  "cells": [
    {
      "metadata": {},
      "source": [
        "<td>\n",
        "   <a target=\"_blank\" href=\"https://labelbox.com\" ><img src=\"https://labelbox.com/blog/content/images/2021/02/logo-v4.svg\" width=256/></a>\n",
        "</td>"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "<td>\n",
        "<a href=\"https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/basics/custom_embeddings.ipynb\" target=\"_blank\"><img\n",
        "src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"></a>\n",
        "</td>\n",
        "\n",
        "<td>\n",
        "<a href=\"https://github.com/Labelbox/labelbox-python/tree/master/examples/basics/custom_embeddings.ipynb\" target=\"_blank\"><img\n",
        "src=\"https://img.shields.io/badge/GitHub-100000?logo=github&logoColor=white\" alt=\"GitHub\"></a>\n",
        "</td>"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "# Documentation\n",
        "Please read this document before getting started.\n",
        "https://docs.google.com/document/d/1C_zZFGNjXq10P1MvEX6MM0TC7HHrkFOp9BB0P_S_2MQ"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "# Imports"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "!pip3 install -q \"labelbox[data]\""
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "import labelbox as lb\n",
        "import numpy as np\n",
        "import json"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Install the wheel from Github"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "!pip3 install -q 'git+https://github.com/Labelbox/advlib.git'"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Labelbox Credentials"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "API_KEY = \"\"\n",
        "client = lb.Client(API_KEY)\n",
        "\n",
        "# set LABELBOX_API_KEY in bash\n",
        "%env LABELBOX_API_KEY=$API_KEY\n",
        "# sanity check it worked\n",
        "!echo $LABELBOX_API_KEY"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Select data rows in Labelbox for custom embeddings"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "client.enable_experimental = True\n",
        "\n",
        "# get images from a Labelbox dataset\n",
        "# Our systems start to process data after 1000 embeddings of each type, for this demo make sure your dataset is over 1000 data rows\n",
        "dataset = client.get_dataset(\"<ADD YOUR DATASET ID>\")\n",
        "\n",
        "export_task = dataset.export()\n",
        "export_task.wait_till_done()"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "data_rows = []\n",
        "\n",
        "def json_stream_handler(output: lb.JsonConverterOutput):\n",
        "  data_row = json.loads(output.json_str)\n",
        "  data_rows.append(data_row)\n",
        "\n",
        "if export_task.has_errors():\n",
        "  export_task.get_stream(\n",
        "  converter=lb.JsonConverter(),\n",
        "  stream_type=lb.StreamType.ERRORS\n",
        "  ).start(stream_handler=lambda error: print(error))\n",
        "\n",
        "if export_task.has_result():\n",
        "  export_json = export_task.get_stream(\n",
        "    converter=lb.JsonConverter(),\n",
        "    stream_type=lb.StreamType.RESULT\n",
        "  ).start(stream_handler=json_stream_handler)"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "data_row_ids = [dr[\"data_row\"][\"id\"] for dr in data_rows]\n",
        "\n",
        "data_row_ids = data_row_ids[:1000] # keep the first 1000 examples for the sake of this demo"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Create the payload for custom embeddings\n",
        "It should be a .ndjson file.   \n",
        "Every line is a json file that finishes with a \\n character.  \n",
        "It does not have to be created through python.  "
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "nb_data_rows = len(data_row_ids)\n",
        "print(\"Number of data rows: \", nb_data_rows)\n",
        "# generate random vectors, of dimension 2048 each\n",
        "# Labelbox supports custom embedding vectors of dimension up to 2048\n",
        "custom_embeddings = [list(np.random.random(2048)) for _ in range(nb_data_rows)]"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# create the payload for custom embeddings\n",
        "payload = []\n",
        "for data_row_id,custom_embedding in zip(data_row_ids,custom_embeddings):\n",
        "  payload.append({\"id\": data_row_id, \"vector\": custom_embedding})\n",
        "\n",
        "print('payload', len(payload),payload[:1])"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# delete any pre-existing file\n",
        "import os\n",
        "if os.path.exists(\"payload.ndjson\"):\n",
        "  os.remove(\"payload.ndjson\")\n",
        "\n",
        "# convert the payload to json file\n",
        "with open('payload.ndjson', 'w') as f:\n",
        "  for p in payload:\n",
        "    f.write(json.dumps(p) + \"\\n\")\n",
        "    # sanity_check_payload = json.dump(payload, f)"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# sanity check that you can read/load the file and the payload is correct\n",
        "with open('payload.ndjson') as f:\n",
        "    sanity_check_payload = [json.loads(l) for l in f.readlines()]\n",
        "print(\"Nb of custom embedding vectors in sanity_check_payload: \", len(sanity_check_payload))"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# See all custom embeddings available in your Labelbox workspace\n",
        "!advtool embeddings list"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# # Create a new custom embedding, unless you want to re-use one\n",
        "!advtool embeddings create my_custom_embedding_2048_dimensions 2048\n",
        "# this command will return the ID of the newly created embedding, e.g. ciqtgd94607290000ljx4dvh2"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# # Delete a custom embedding\n",
        "# !advtool embeddings delete ciqtgd94607290000ljx4dvh2"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Upload the payload to Labelbox"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "# Replace the current id with the newly generated id from the previous step, or any existing custom embedding id.\n",
        "!advtool embeddings import c933bviqn0756000elk07et77 ./payload.ndjson"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "# Pick an existing custom embedding, or create a custom embedding"
      ],
      "cell_type": "markdown"
    },
    {
      "metadata": {},
      "source": [
        "# count how many data rows have a specific custom embedding (This can take a couple of minutes)\n",
        "!advtool embeddings count c933bviqn0756000elk07et77"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    },
    {
      "metadata": {},
      "source": [
        "print(len(payload))"
      ],
      "cell_type": "code",
      "outputs": [],
      "execution_count": null
    }
  ]
}